import torch
from torch import nn


class Policy(torch.nn.Module):
    """Two-layer MLP policy mapping a flattened input to action probabilities.

    Args:
        num_action: Number of discrete actions (size of the output distribution).
        input_size: Number of features per sample after flattening; the default
            corresponds to a 28x28 image.
        hidden_size: Width of the hidden layer.
    """

    def __init__(self, num_action, input_size=28 * 28, hidden_size=64):
        super().__init__()
        self.input_size = input_size
        # Keep the Sequential layout (Linear, ReLU, Linear) so parameter names
        # (layer1.0.* and layer1.2.*) match checkpoints saved by the earlier
        # version. The final Linear emits raw logits on purpose: a ReLU here
        # would clamp every negative logit to 0, making those actions
        # indistinguishable after softmax and cutting off their gradients.
        self.layer1 = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_action),
        )
        self.layer2 = nn.Softmax(dim=1)

    def forward(self, x):
        """Return action probabilities for ``x``.

        Args:
            x: Tensor whose element count is a multiple of ``input_size``,
                e.g. a single (28, 28) image or a batch shaped (B, 1, 28, 28).

        Returns:
            Tensor of shape (B, num_action) where each row sums to 1.
        """
        # reshape (not view) also accepts non-contiguous inputs, and -1 keeps
        # the batch dimension instead of hard-coding it to 1.
        x = x.reshape(-1, self.input_size)
        logits = self.layer1(x)
        return self.layer2(logits)
